{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unit test writing using a multi-step prompt\n",
    "\n",
    "Complex tasks, such as writing unit tests, can benefit from multi-step prompts. In contrast to a single prompt, a multi-step prompt generates text from GPT and then feeds that output text back into subsequent prompts. This can help in cases where you want GPT to reason things out before answering, or brainstorm a plan before executing it.\n",
    "\n",
    "In this notebook, we use a 3-step prompt to write unit tests in Python using the following steps:\n",
    "\n",
    "1. **Explain**: Given a Python function, we ask GPT to explain what the function is doing and why.\n",
    "2. **Plan**: We ask GPT to plan a set of unit tests for the function.\n",
    "    - If the plan is too short, we ask GPT to elaborate with more ideas for unit tests.\n",
    "3. **Execute**: Finally, we instruct GPT to write unit tests that cover the planned cases.\n",
    "\n",
    "The code example illustrates a few embellishments on the chained, multi-step prompt:\n",
    "\n",
    "- Conditional branching (e.g., asking for elaboration only if the first plan is too short)\n",
    "- The choice of different models for different steps\n",
    "- A check that re-runs the function if the output is unsatisfactory (e.g., if the output code cannot be parsed by Python's `ast` module)\n",
    "- Streaming output so that you can start reading the output before it's fully generated (handy for long, multi-step outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports needed to run the code in this notebook\n",
    "import ast  # used for detecting whether generated Python code is valid\n",
    "import os\n",
    "from openai import OpenAI\n",
    "\n",
    "client = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"<your OpenAI API key if not set as env var>\"))\n",
    "\n",
    "color_prefix_by_role = {\n",
    "    \"system\": \"\\033[0m\",  # gray\n",
    "    \"user\": \"\\033[0m\",  # gray\n",
    "    \"assistant\": \"\\033[92m\",  # green\n",
    "}\n",
    "\n",
    "\n",
    "def print_messages(messages, color_prefix_by_role=color_prefix_by_role) -> None:\n",
    "    \"\"\"Prints messages sent to or from GPT.\"\"\"\n",
    "    for message in messages:\n",
    "        role = message[\"role\"]\n",
    "        color_prefix = color_prefix_by_role[role]\n",
    "        content = message[\"content\"]\n",
    "        print(f\"{color_prefix}\\n[{role}]\\n{content}\")\n",
    "\n",
    "\n",
    "def print_message_delta(delta, color_prefix_by_role=color_prefix_by_role) -> None:\n",
    "    \"\"\"Prints a chunk of messages streamed back from GPT.\"\"\"\n",
    "    if \"role\" in delta:\n",
    "        role = delta[\"role\"]\n",
    "        color_prefix = color_prefix_by_role[role]\n",
    "        print(f\"{color_prefix}\\n[{role}]\\n\", end=\"\")\n",
    "    elif \"content\" in delta:\n",
    "        content = delta[\"content\"]\n",
    "        print(content, end=\"\")\n",
    "    else:\n",
    "        pass\n",
    "\n",
    "\n",
    "# example of a function that uses a multi-step prompt to write unit tests\n",
    "def unit_tests_from_function(\n",
    "    function_to_test: str,  # Python function to test, as a string\n",
    "    unit_test_package: str = \"pytest\",  # unit testing package; use the name as it appears in the import statement\n",
    "    approx_min_cases_to_cover: int = 7,  # minimum number of test case categories to cover (approximate)\n",
    "    print_text: bool = False,  # optionally prints text; helpful for understanding the function & debugging\n",
    "    explain_model: str = \"gpt-3.5-turbo\",  # model used to generate text plans in step 1\n",
    "    plan_model: str = \"gpt-3.5-turbo\",  # model used to generate text plans in steps 2 and 2b\n",
    "    execute_model: str = \"gpt-3.5-turbo\",  # model used to generate code in step 3\n",
    "    temperature: float = 0.4,  # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4\n",
    "    reruns_if_fail: int = 1,  # if the output code cannot be parsed, this will re-run the function up to N times\n",
    ") -> str:\n",
    "    \"\"\"Returns a unit test for a given Python function, using a 3-step GPT prompt.\"\"\"\n",
    "\n",
    "    # Step 1: Generate an explanation of the function\n",
    "\n",
    "    # create a markdown-formatted message that asks GPT to explain the function, formatted as a bullet list\n",
    "    explain_system_message = {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": \"You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists.\",\n",
    "    }\n",
    "    explain_user_message = {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"\"\"Please explain the following Python function. Review what each element of the function is doing precisely and what the author's intentions may have been. Organize your explanation as a markdown-formatted, bulleted list.\n",
    "\n",
    "```python\n",
    "{function_to_test}\n",
    "```\"\"\",\n",
    "    }\n",
    "    explain_messages = [explain_system_message, explain_user_message]\n",
    "    if print_text:\n",
    "        print_messages(explain_messages)\n",
    "\n",
    "    explanation_response = client.chat.completions.create(model=explain_model,\n",
    "    messages=explain_messages,\n",
    "    temperature=temperature,\n",
    "    stream=True)\n",
    "    explanation = \"\"\n",
    "    for chunk in explanation_response:\n",
    "        delta = chunk.choices[0].delta\n",
    "        if print_text:\n",
    "            print_message_delta(delta)\n",
    "        if \"content\" in delta:\n",
    "            explanation += delta.content\n",
    "    explain_assistant_message = {\"role\": \"assistant\", \"content\": explanation}\n",
    "\n",
    "    # Step 2: Generate a plan to write a unit test\n",
    "\n",
    "    # Asks GPT to plan out cases the units tests should cover, formatted as a bullet list\n",
    "    plan_user_message = {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"\"\"A good unit test suite should aim to:\n",
    "- Test the function's behavior for a wide range of possible inputs\n",
    "- Test edge cases that the author may not have foreseen\n",
    "- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain\n",
    "- Be easy to read and understand, with clean code and descriptive names\n",
    "- Be deterministic, so that the tests always pass or fail in the same way\n",
    "\n",
    "To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets).\"\"\",\n",
    "    }\n",
    "    plan_messages = [\n",
    "        explain_system_message,\n",
    "        explain_user_message,\n",
    "        explain_assistant_message,\n",
    "        plan_user_message,\n",
    "    ]\n",
    "    if print_text:\n",
    "        print_messages([plan_user_message])\n",
    "    plan_response = client.chat.completions.create(model=plan_model,\n",
    "    messages=plan_messages,\n",
    "    temperature=temperature,\n",
    "    stream=True)\n",
    "    plan = \"\"\n",
    "    for chunk in plan_response:\n",
    "        delta = chunk.choices[0].delta\n",
    "        if print_text:\n",
    "            print_message_delta(delta)\n",
    "        if \"content\" in delta:\n",
    "            explanation += delta.content\n",
    "    plan_assistant_message = {\"role\": \"assistant\", \"content\": plan}\n",
    "\n",
    "    # Step 2b: If the plan is short, ask GPT to elaborate further\n",
    "    # this counts top-level bullets (e.g., categories), but not sub-bullets (e.g., test cases)\n",
    "    num_bullets = max(plan.count(\"\\n-\"), plan.count(\"\\n*\"))\n",
    "    elaboration_needed = num_bullets < approx_min_cases_to_cover\n",
    "    if elaboration_needed:\n",
    "        elaboration_user_message = {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": f\"\"\"In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).\"\"\",\n",
    "        }\n",
    "        elaboration_messages = [\n",
    "            explain_system_message,\n",
    "            explain_user_message,\n",
    "            explain_assistant_message,\n",
    "            plan_user_message,\n",
    "            plan_assistant_message,\n",
    "            elaboration_user_message,\n",
    "        ]\n",
    "        if print_text:\n",
    "            print_messages([elaboration_user_message])\n",
    "        elaboration_response = client.chat.completions.create(model=plan_model,\n",
    "        messages=elaboration_messages,\n",
    "        temperature=temperature,\n",
    "        stream=True)\n",
    "        elaboration = \"\"\n",
    "        for chunk in elaboration_response:\n",
    "            delta = chunk.choices[0].delta\n",
    "        if print_text:\n",
    "            print_message_delta(delta)\n",
    "        if \"content\" in delta:\n",
    "            explanation += delta.content\n",
    "        elaboration_assistant_message = {\"role\": \"assistant\", \"content\": elaboration}\n",
    "\n",
    "    # Step 3: Generate the unit test\n",
    "\n",
    "    # create a markdown-formatted prompt that asks GPT to complete a unit test\n",
    "    package_comment = \"\"\n",
    "    if unit_test_package == \"pytest\":\n",
    "        package_comment = \"# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator\"\n",
    "    execute_system_message = {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": \"You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You write careful, accurate unit tests. When asked to reply only with code, you write all of your code in a single block.\",\n",
    "    }\n",
    "    execute_user_message = {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"\"\"Using Python and the `{unit_test_package}` package, write a suite of unit tests for the function, following the cases above. Include helpful comments to explain each line. Reply only with code, formatted as follows:\n",
    "\n",
    "```python\n",
    "# imports\n",
    "import {unit_test_package}  # used for our unit tests\n",
    "{{insert other imports as needed}}\n",
    "\n",
    "# function to test\n",
    "{function_to_test}\n",
    "\n",
    "# unit tests\n",
    "{package_comment}\n",
    "{{insert unit test code here}}\n",
    "```\"\"\",\n",
    "    }\n",
    "    execute_messages = [\n",
    "        execute_system_message,\n",
    "        explain_user_message,\n",
    "        explain_assistant_message,\n",
    "        plan_user_message,\n",
    "        plan_assistant_message,\n",
    "    ]\n",
    "    if elaboration_needed:\n",
    "        execute_messages += [elaboration_user_message, elaboration_assistant_message]\n",
    "    execute_messages += [execute_user_message]\n",
    "    if print_text:\n",
    "        print_messages([execute_system_message, execute_user_message])\n",
    "\n",
    "    execute_response = client.chat.completions.create(model=execute_model,\n",
    "        messages=execute_messages,\n",
    "        temperature=temperature,\n",
    "        stream=True)\n",
    "    execution = \"\"\n",
    "    for chunk in execute_response:\n",
    "        delta = chunk.choices[0].delta\n",
    "        if print_text:\n",
    "            print_message_delta(delta)\n",
    "        if delta.content:\n",
    "            execution += delta.content\n",
    "\n",
    "    # check the output for errors\n",
    "    code = execution.split(\"```python\")[1].split(\"```\")[0].strip()\n",
    "    try:\n",
    "        ast.parse(code)\n",
    "    except SyntaxError as e:\n",
    "        print(f\"Syntax error in generated code: {e}\")\n",
    "        if reruns_if_fail > 0:\n",
    "            print(\"Rerunning...\")\n",
    "            return unit_tests_from_function(\n",
    "                function_to_test=function_to_test,\n",
    "                unit_test_package=unit_test_package,\n",
    "                approx_min_cases_to_cover=approx_min_cases_to_cover,\n",
    "                print_text=print_text,\n",
    "                explain_model=explain_model,\n",
    "                plan_model=plan_model,\n",
    "                execute_model=execute_model,\n",
    "                temperature=temperature,\n",
    "                reruns_if_fail=reruns_if_fail\n",
    "                - 1,  # decrement rerun counter when calling again\n",
    "            )\n",
    "\n",
    "    # return the unit test as a string\n",
    "    return code\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0m\n",
      "[system]\n",
      "You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists.\n",
      "\u001b[0m\n",
      "[user]\n",
      "Please explain the following Python function. Review what each element of the function is doing precisely and what the author's intentions may have been. Organize your explanation as a markdown-formatted, bulleted list.\n",
      "\n",
      "```python\n",
      "def pig_latin(text):\n",
      "    def translate(word):\n",
      "        vowels = 'aeiou'\n",
      "        if word[0] in vowels:\n",
      "            return word + 'way'\n",
      "        else:\n",
      "            consonants = ''\n",
      "            for letter in word:\n",
      "                if letter not in vowels:\n",
      "                    consonants += letter\n",
      "                else:\n",
      "                    break\n",
      "            return word[len(consonants):] + consonants + 'ay'\n",
      "\n",
      "    words = text.lower().split()\n",
      "    translated_words = [translate(word) for word in words]\n",
      "    return ' '.join(translated_words)\n",
      "\n",
      "```\n",
      "\u001b[0m\n",
      "[user]\n",
      "A good unit test suite should aim to:\n",
      "- Test the function's behavior for a wide range of possible inputs\n",
      "- Test edge cases that the author may not have foreseen\n",
      "- Take advantage of the features of `pytest` to make the tests easy to write and maintain\n",
      "- Be easy to read and understand, with clean code and descriptive names\n",
      "- Be deterministic, so that the tests always pass or fail in the same way\n",
      "\n",
      "To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets).\n",
      "\u001b[0m\n",
      "[user]\n",
      "In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).\n",
      "\u001b[0m\n",
      "[system]\n",
      "You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You write careful, accurate unit tests. When asked to reply only with code, you write all of your code in a single block.\n",
      "\u001b[0m\n",
      "[user]\n",
      "Using Python and the `pytest` package, write a suite of unit tests for the function, following the cases above. Include helpful comments to explain each line. Reply only with code, formatted as follows:\n",
      "\n",
      "```python\n",
      "# imports\n",
      "import pytest  # used for our unit tests\n",
      "{insert other imports as needed}\n",
      "\n",
      "# function to test\n",
      "def pig_latin(text):\n",
      "    def translate(word):\n",
      "        vowels = 'aeiou'\n",
      "        if word[0] in vowels:\n",
      "            return word + 'way'\n",
      "        else:\n",
      "            consonants = ''\n",
      "            for letter in word:\n",
      "                if letter not in vowels:\n",
      "                    consonants += letter\n",
      "                else:\n",
      "                    break\n",
      "            return word[len(consonants):] + consonants + 'ay'\n",
      "\n",
      "    words = text.lower().split()\n",
      "    translated_words = [translate(word) for word in words]\n",
      "    return ' '.join(translated_words)\n",
      "\n",
      "\n",
      "# unit tests\n",
      "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator\n",
      "{insert unit test code here}\n",
      "```\n",
      "execute messages: [{'role': 'system', 'content': 'You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You write careful, accurate unit tests. When asked to reply only with code, you write all of your code in a single block.'}, {'role': 'user', 'content': \"Please explain the following Python function. Review what each element of the function is doing precisely and what the author's intentions may have been. Organize your explanation as a markdown-formatted, bulleted list.\\n\\n```python\\ndef pig_latin(text):\\n    def translate(word):\\n        vowels = 'aeiou'\\n        if word[0] in vowels:\\n            return word + 'way'\\n        else:\\n            consonants = ''\\n            for letter in word:\\n                if letter not in vowels:\\n                    consonants += letter\\n                else:\\n                    break\\n            return word[len(consonants):] + consonants + 'ay'\\n\\n    words = text.lower().split()\\n    translated_words = [translate(word) for word in words]\\n    return ' '.join(translated_words)\\n\\n```\"}, {'role': 'assistant', 'content': ''}, {'role': 'user', 'content': \"A good unit test suite should aim to:\\n- Test the function's behavior for a wide range of possible inputs\\n- Test edge cases that the author may not have foreseen\\n- Take advantage of the features of `pytest` to make the tests easy to write and maintain\\n- Be easy to read and understand, with clean code and descriptive names\\n- Be deterministic, so that the tests always pass or fail in the same way\\n\\nTo help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets).\"}, {'role': 'assistant', 'content': ''}, {'role': 'user', 'content': 'In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).'}, 
{'role': 'assistant', 'content': ''}, {'role': 'user', 'content': \"Using Python and the `pytest` package, write a suite of unit tests for the function, following the cases above. Include helpful comments to explain each line. Reply only with code, formatted as follows:\\n\\n```python\\n# imports\\nimport pytest  # used for our unit tests\\n{insert other imports as needed}\\n\\n# function to test\\ndef pig_latin(text):\\n    def translate(word):\\n        vowels = 'aeiou'\\n        if word[0] in vowels:\\n            return word + 'way'\\n        else:\\n            consonants = ''\\n            for letter in word:\\n                if letter not in vowels:\\n                    consonants += letter\\n                else:\\n                    break\\n            return word[len(consonants):] + consonants + 'ay'\\n\\n    words = text.lower().split()\\n    translated_words = [translate(word) for word in words]\\n    return ' '.join(translated_words)\\n\\n\\n# unit tests\\n# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator\\n{insert unit test code here}\\n```\"}]\n"
     ]
    }
   ],
   "source": [
    "example_function = \"\"\"def pig_latin(text):\n",
    "    def translate(word):\n",
    "        vowels = 'aeiou'\n",
    "        if word[0] in vowels:\n",
    "            return word + 'way'\n",
    "        else:\n",
    "            consonants = ''\n",
    "            for letter in word:\n",
    "                if letter not in vowels:\n",
    "                    consonants += letter\n",
    "                else:\n",
    "                    break\n",
    "            return word[len(consonants):] + consonants + 'ay'\n",
    "\n",
    "    words = text.lower().split()\n",
    "    translated_words = [translate(word) for word in words]\n",
    "    return ' '.join(translated_words)\n",
    "\"\"\"\n",
    "\n",
    "unit_tests = unit_tests_from_function(\n",
    "    example_function,\n",
    "    approx_min_cases_to_cover=10,\n",
    "    print_text=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# imports\n",
      "import pytest\n",
      "\n",
      "# function to test\n",
      "def pig_latin(text):\n",
      "    def translate(word):\n",
      "        vowels = 'aeiou'\n",
      "        if word[0] in vowels:\n",
      "            return word + 'way'\n",
      "        else:\n",
      "            consonants = ''\n",
      "            for letter in word:\n",
      "                if letter not in vowels:\n",
      "                    consonants += letter\n",
      "                else:\n",
      "                    break\n",
      "            return word[len(consonants):] + consonants + 'ay'\n",
      "\n",
      "    words = text.lower().split()\n",
      "    translated_words = [translate(word) for word in words]\n",
      "    return ' '.join(translated_words)\n",
      "\n",
      "\n",
      "# unit tests\n",
      "@pytest.mark.parametrize('text, expected', [\n",
      "    ('hello world', 'ellohay orldway'),  # basic test case\n",
      "    ('Python is awesome', 'ythonPay isway awesomeway'),  # test case with multiple words\n",
      "    ('apple', 'appleway'),  # test case with a word starting with a vowel\n",
      "    ('', ''),  # test case with an empty string\n",
      "    ('123', '123'),  # test case with non-alphabetic characters\n",
      "    ('Hello World!', 'elloHay orldWay!'),  # test case with punctuation\n",
      "    ('The quick brown fox', 'ethay ickquay ownbray oxfay'),  # test case with mixed case words\n",
      "    ('a e i o u', 'away eway iway oway uway'),  # test case with all vowels\n",
      "    ('bcd fgh jkl mnp', 'bcday fghay jklway mnpay'),  # test case with all consonants\n",
      "])\n",
      "def test_pig_latin(text, expected):\n",
      "    assert pig_latin(text) == expected\n"
     ]
    }
   ],
   "source": [
    "print(unit_tests)"
   ]
  },
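  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before relying on a generated suite like the one printed above, it can help to run its cases against the function and see which expectations actually hold. The following is a minimal sketch rather than part of the pipeline itself: it copies `pig_latin` and the parametrized cases from the output above, and `find_failing_cases` is a hypothetical helper name introduced here purely for illustration.\n",
    "\n",
    "```python\n",
    "# function under test, copied from the example above\n",
    "def pig_latin(text):\n",
    "    def translate(word):\n",
    "        vowels = 'aeiou'\n",
    "        if word[0] in vowels:\n",
    "            return word + 'way'\n",
    "        else:\n",
    "            consonants = ''\n",
    "            for letter in word:\n",
    "                if letter not in vowels:\n",
    "                    consonants += letter\n",
    "                else:\n",
    "                    break\n",
    "            return word[len(consonants):] + consonants + 'ay'\n",
    "\n",
    "    words = text.lower().split()\n",
    "    translated_words = [translate(word) for word in words]\n",
    "    return ' '.join(translated_words)\n",
    "\n",
    "\n",
    "# (input, expected) pairs copied from the generated test suite above\n",
    "generated_cases = [\n",
    "    ('hello world', 'ellohay orldway'),\n",
    "    ('Python is awesome', 'ythonPay isway awesomeway'),\n",
    "    ('apple', 'appleway'),\n",
    "    ('', ''),\n",
    "    ('123', '123'),\n",
    "    ('Hello World!', 'elloHay orldWay!'),\n",
    "    ('The quick brown fox', 'ethay ickquay ownbray oxfay'),\n",
    "    ('a e i o u', 'away eway iway oway uway'),\n",
    "    ('bcd fgh jkl mnp', 'bcday fghay jklway mnpay'),\n",
    "]\n",
    "\n",
    "\n",
    "# hypothetical helper: collects (input, expected, actual) for each case that does not hold\n",
    "def find_failing_cases(func, cases):\n",
    "    failures = []\n",
    "    for text, expected in cases:\n",
    "        actual = func(text)\n",
    "        if actual != expected:\n",
    "            failures.append((text, expected, actual))\n",
    "    return failures\n",
    "\n",
    "\n",
    "failures = find_failing_cases(pig_latin, generated_cases)\n",
    "for text, expected, actual in failures:\n",
    "    print(f'{text!r}: expected {expected!r}, got {actual!r}')\n",
    "```\n",
    "\n",
    "Running the real suite with `pytest` gives the same information; the loop is just a dependency-free way to see which generated expectations disagree with the function's actual behavior."
   ]
  },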
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Make sure to check any code before using it, as GPT makes plenty of mistakes (especially on character-based tasks like this one). For best results, use the most powerful model (GPT-4, as of May 2023)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.9 ('openai')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
